from bert_score import BERTScorer
import numpy as np
import json
from tqdm import tqdm
import argparse
import os

# Module-level BERTScorer, constructed once when this module is loaded and
# used by compute_bertscore_matrix. Configured for English (lang="en") with
# baseline rescaling enabled (rescale_with_baseline=True).
scorer = BERTScorer(lang="en", rescale_with_baseline=True)

def compute_bertscore_matrix(candidates):
    """
    Compute the pairwise BERTScore F1 matrix for all candidates.

    Entry ``[i, j]`` is the F1 score obtained with ``candidates[i]`` passed
    as the candidate text and ``candidates[j]`` as the reference text. The
    diagonal is always overwritten with 1.0.

    :param candidates: List of candidate texts.
    :return: A numpy array of shape (n, n) containing pairwise BERTScore F1 scores.
             An empty input yields an empty (0, 0) array.
    """
    n = len(candidates)
    if n == 0:
        # Nothing to compare; do not hand empty lists to the scorer.
        return np.zeros((0, 0))

    # Score all n * n ordered pairs in a single call rather than one call per
    # row. Pairs are laid out row-major: position i * n + j holds
    # (candidates[i], candidates[j]), so a plain reshape recovers the matrix.
    hypotheses = [text for text in candidates for _ in range(n)]
    references = list(candidates) * n
    _, _, f1 = scorer.score(hypotheses, references)

    # astype() also gives us a writable float64 copy, independent of the
    # tensor's storage, matching the dtype of the original np.zeros matrix.
    bertscore_matrix = f1.numpy().reshape(n, n).astype(np.float64)

    # Force the diagonal to exactly 1.0: a candidate compared with itself is
    # treated as a perfect match regardless of the value the scorer returned.
    np.fill_diagonal(bertscore_matrix, 1.0)

    return bertscore_matrix

def minimum_bayes_risk(candidates):
    """
    Perform Minimum Bayes Risk decoding using BERTScore.

    The risk of a candidate is the mean of ``1 - F1`` over its row of the
    pairwise BERTScore matrix (its own diagonal entry contributes zero risk).
    When several candidates share the lowest risk, the earliest one is
    returned, because ``np.argmin`` reports the first minimum.

    :param candidates: List of candidate texts.
    :return: The selected candidate with the lowest expected risk.
    :raises ValueError: If ``candidates`` is empty.
    """
    if len(candidates) == 0:
        # np.argmin would fail on an empty array with an opaque message;
        # raise the same exception type with an explicit explanation.
        raise ValueError("minimum_bayes_risk requires at least one candidate")

    if len(candidates) == 1:
        # A lone candidate is trivially the minimiser; skip the scorer.
        return candidates[0]

    bertscore_matrix = compute_bertscore_matrix(candidates)

    # Compute expected risk for each candidate (row-wise mean of 1 - F1)
    expected_risk = np.mean(1 - bertscore_matrix, axis=1)

    # Select the candidate with the lowest expected risk
    best_index = int(np.argmin(expected_risk))
    return candidates[best_index]

if __name__ == "__main__":
    # Command-line driver: for every file in --input_dir, load its JSON,
    # replace each topic's "Responses" list with the single response selected
    # by minimum_bayes_risk, and write the result under the same file name in
    # --output_dir.
    parser = argparse.ArgumentParser(description="Generate LLM-based consensus from responses.")
    parser.add_argument('--input_dir', type=str, required=True, help="Input dir path to the input JSON files.")
    parser.add_argument('--output_dir', type=str, required=True, help="Output dir path to the output JSON files.")
    args = parser.parse_args()

    # Create the output directory up front so that writing the first result
    # cannot fail merely because the directory does not exist yet.
    os.makedirs(args.output_dir, exist_ok=True)

    # Sorted so files are processed in a reproducible order.
    for filename in tqdm(sorted(os.listdir(args.input_dir))):
        file_path = os.path.join(args.input_dir, filename)
        if not os.path.isfile(file_path):
            # os.listdir also returns sub-directories, which cannot be opened
            # as JSON files; skip them.
            continue

        with open(file_path, 'r', encoding='utf-8') as json_file:
            data = json.load(json_file)

        # Each topic is expected to carry its candidate texts under "Responses".
        for topic in tqdm(data):
            best_candidate = minimum_bayes_risk(topic["Responses"])
            topic["Responses"] = [best_candidate]

        res_file_path = os.path.join(args.output_dir, filename)
        with open(res_file_path, 'w', encoding='utf-8') as json_file:
            json.dump(data, json_file, indent=4)

        print("Done ", filename)